In [92]:
# Interpreter-level setup, executed before any numerical library is imported.
import os
import sys
import warnings

# Raise the recursion ceiling to 10000 frames.
# NOTE(review): presumably needed for deep graph traversal or pickling -- confirm.
sys.setrecursionlimit(10000)

# Keep DeprecationWarning messages out of the cell outputs.
warnings.filterwarnings('ignore', category=DeprecationWarning)

# NOTE(review): presumably read by gnumpy at import time, which would explain why it is
# set in the very first cell -- confirm against the gnumpy sources.
os.environ['GNUMPY_IMPLICIT_CONVERSION'] = 'ignore'
print(os.environ.get('GNUMPY_IMPLICIT_CONVERSION'))
ignore

In [132]:
import cPickle
import gzip

from breze.learn.data import one_hot
from breze.learn.base import cast_array_to_local_type
from breze.learn.utils import tile_raster_images

import climin.stops
import climin.initialize
from climin import mathadapt as ma

from breze.learn import hvi
from breze.learn.hvi import HmcViModel
from breze.learn.hvi.energies import (NormalGaussKinEnergyMixin, DiagGaussKinEnergyMixin)
from breze.learn.hvi.inversemodels import MlpGaussInvModelMixin

from matplotlib import pyplot as plt
from matplotlib import cm

import numpy as np

#import fasttsne

from IPython.html import widgets
%matplotlib inline

import theano
theano.config.compute_test_value = 'ignore'#'raise'
from theano import (tensor as T, clone)
In [94]:
# Gzipped pickle holding the three MNIST splits.
datafile = '../mnist.pkl.gz'

# NOTE: unpickling can execute arbitrary code -- only load a trusted copy of this file.
with gzip.open(datafile, 'rb') as f:
    train_set, val_set, test_set = cPickle.load(f)

# Every split unpacks into an (inputs, labels) pair.
X, Z = train_set
VX, VZ = val_set
TX, TZ = test_set

# Labels become 10-column one-hot matrices.
Z, VZ, TZ = one_hot(Z, 10), one_hot(VZ, 10), one_hot(TZ, 10)

# Keep handles on the un-binarized inputs before X/VX/TX are overwritten below.
X_no_bin, VX_no_bin, TX_no_bin = X, VX, TX

# Binarize: each entry is replaced by one Bernoulli draw that uses the entry as its
# success probability. The draw order (train, validation, test) after seeding is kept
# fixed so the result is reproducible.
np.random.seed(0)
X = np.random.binomial(1, X) * 1.0
VX = np.random.binomial(1, VX) * 1.0
TX = np.random.binomial(1, TX) * 1.0

image_dims = 28, 28

# References to the arrays as they are before conversion to the local array type.
X_np, Z_np = X, Z
VX_np, VZ_np = VX, VZ
TX_np, TZ_np = TX, TZ
X_no_bin_np, VX_no_bin_np, TX_no_bin_np = X_no_bin, VX_no_bin, TX_no_bin

# Working copies converted with breze's cast_array_to_local_type.
X = cast_array_to_local_type(X_np)
Z = cast_array_to_local_type(Z_np)
VX = cast_array_to_local_type(VX_np)
VZ = cast_array_to_local_type(VZ_np)
TX = cast_array_to_local_type(TX_np)
TZ = cast_array_to_local_type(TZ_np)
X_no_bin = cast_array_to_local_type(X_no_bin_np)
VX_no_bin = cast_array_to_local_type(VX_no_bin_np)
TX_no_bin = cast_array_to_local_type(TX_no_bin_np)

print(X.shape)
(50000L, 784L)

In [95]:
# Preview of the first 64 binarized training rows, tiled into a single image.
# NOTE(review): (8, 8) and (1, 1) are presumably the tile grid and tile spacing -- confirm.
preview_digits = X_np[:64]
img = tile_raster_images(preview_digits, image_dims, (8, 8), (1, 1))

fig, ax = plt.subplots(figsize=(9, 9))
ax.imshow(img, cmap=cm.binary)
Out[95]:
<matplotlib.image.AxesImage at 0x6832ebe0>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [5]:
# Toggle between the fast-dropout mixins and the plain MLP mixins.
fast_dropout = False

if fast_dropout:
    visible_mixin = hvi.FastDropoutMlpBernoulliVisibleVAEMixin
    latent_mixin = hvi.FastDropoutMlpGaussLatentVAEMixin
    # Extra constructor arguments that only the fast-dropout variant receives.
    kwargs = {
        'p_dropout_inpt': .1,
        'p_dropout_hiddens': [.2, .2],
    }
    print('yeah')
else:
    visible_mixin = hvi.MlpBernoulliVisibleVAEMixin
    latent_mixin = hvi.MlpGaussLatentVAEMixin
    kwargs = {}


# Concrete model: HmcViModel combined with the chosen visible/latent mixins, a diagonal
# Gaussian kinetic energy and an MLP Gaussian inverse model. The base order is significant
# because it determines the method resolution order.
class MyHmcViModel(HmcViModel,
                   visible_mixin,
                   latent_mixin,
                   DiagGaussKinEnergyMixin,
                   MlpGaussInvModelMixin):
    pass


batch_size = 500
# Earlier optimizer settings, kept for reference:
#optimizer = 'rmsprop', {'step_rate': 1e-4, 'momentum': 0.95, 'decay': .95, 'offset': 1e-6}
#optimizer = 'adam', {'step_rate': .5, 'momentum': 0.9, 'decay': .95, 'offset': 1e-6}
optimizer = 'adam', {'step_rate': 0.0005}

# n_latents counts the random variables, NOT the size of their sufficient statistics.
n_latents = 2
n_hidden = 200

# Three (hidden sizes, transfer functions) pairs follow the input and latent sizes.
# NOTE(review): which pair belongs to which sub-network is not visible here -- check the
# HmcViModel constructor before reordering anything.
m = MyHmcViModel(
    X.shape[1], n_latents,
    [n_hidden] * 2, ['rectifier', 'rectifier'],
    [n_hidden] * 2, ['rectifier', 'rectifier'],
    [n_hidden], ['rectifier'],
    n_hmc_steps=3,
    n_lf_steps=4,
    n_z_samples=1,
    optimizer=optimizer,
    batch_size=batch_size,
    allow_partial_velocity_update=False,
    perform_acceptance_step=False,
    **kwargs)

# Disabled initialisation experiments, kept for reference:
#climin.initialize.randomize_normal(m.parameters.data, 0.1, 1e-1)
#m.parameters.__setitem__(m.hmc_sampler.step_size_param, 0.2)
#m.parameters.__setitem__(m.kin_energy.mlp.layers[-1].bias, 1)
\\srv-file.brml.tum.de\nthome\cwolf\code\ml_support\theano\theano\scan_module\scan_perform_ext.py:135: RuntimeWarning: numpy.ndarray size changed, may indicate binary incompatibility
  from scan_perform.scan_perform import *

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [6]:
# No earlier checkpoint has been loaded at this point.
old_best_params = None

# Shape of the model's flat parameter vector.
print(m.parameters.data.shape)
(554795,)

In [155]:
# Pickled flat parameter vector to restore.
FILENAME = 'hvi_gen2_recog2_aux1_late2_hid200_fullbin_untrained_new.pkl'

# The context manager closes the file even if unpickling raises.
# NOTE: unpickling can execute arbitrary code -- only load checkpoints you trust.
with open(FILENAME, 'rb') as f:
    np_array = cPickle.load(f)

# Convert the loaded array with breze's cast_array_to_local_type before it is handed
# to the model.
old_best_params = cast_array_to_local_type(np_array)
print(old_best_params.shape)
(554795,)

In [156]:
# Replace the model's flat parameter vector with a copy of the restored checkpoint
# (a copy, presumably so that later training does not modify old_best_params -- confirm).
m.parameters.data = old_best_params.copy()
# Disabled: would record the validation score of the restored parameters.
#old_best_loss = m.score(VX)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [175]:
# Model score on the binarized validation inputs, then on the binarized test inputs.
for split in (VX, TX):
    print(m.score(split))
garray(180.01100158691406)
garray(181.55076599121094)

In [176]:
# Current values of the bias of layer index 2 in the initial recognition MLP.
recog_bias = m.init_recog.mlp.layers[2].bias
print(m.parameters.view(recog_bias))
garray([-0.34611687, -0.09502647, -1.77520561, -2.25207138])

In [174]:
# Manual overrides written straight into the model's parameter set.

# Raw step-size parameter of the HMC sampler.
m.parameters[m.hmc_sampler.step_size_param] = 0.2

# Bias of layer index 2 in the initial recognition MLP.
m.parameters[m.init_recog.mlp.layers[2].bias] = cast_array_to_local_type(
    np.array([-0.34611687, -0.09502647, -1.77520561, -2.25207138]))

# Variance parameter of the kinetic-energy model (one entry per latent dimension here).
m.parameters[m.kin_energy.variance_parameter] = cast_array_to_local_type(
    np.array([-0.7, -0.7]))
In [159]:
# NOTE(review): this presumably mirrors how the sampler maps its raw step-size parameter
# to the step size it actually uses -- confirm against the sampler implementation.
raw_step_size = m.parameters.view(m.hmc_sampler.step_size_param)
print(0.1 * raw_step_size ** 2 + 1e-8)
garray([ 0.00400001])

In [87]:
#print m.estimate_nll(TX, 500)
142.505125

In [177]:
# Model score on the un-binarized validation inputs, then on the un-binarized test inputs.
for split in (VX_no_bin, TX_no_bin):
    print(m.score(split))
garray(180.8560028076172)
garray(181.33506774902344)

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [178]:
# Checkpoint destination: TARGET_FILENAME + FILETYPE_EXTENSION.
TARGET_FILENAME = 'hvi_gen2_recog2_aux1_late2_hid200_fullbin_3hmc_04lf_pretrained'
FILETYPE_EXTENSION = '.pkl'

# Optional warm start: set BOTH values to carry an earlier best result into this run.
# old_best_loss is initialised here so the comparison below can never hit an undefined name.
old_best_params = None
old_best_loss = None

max_passes = 500
# Explicit floor division keeps both counts integral.
max_iter = max_passes * X.shape[0] // batch_size
n_report = X.shape[0] // batch_size

stop = climin.stops.AfterNIterations(max_iter)   # end after max_iter iterations
pause = climin.stops.ModuloNIterations(n_report)  # report every n_report iterations

# theano.config.optimizer = 'fast_compile'

# NOTE(review): training uses the un-binarized inputs (X_no_bin) while validation uses the
# binarized ones (VX) -- confirm this asymmetry is intended.
for i, info in enumerate(m.powerfit((X_no_bin,), (VX,), stop, pause, eval_train_loss=False)):
    print('%s %s %s' % (i, info['loss'], info['val_loss']))

    # On the first report only: keep the earlier best result if it beats the fresh one.
    if i == 0 and old_best_params is not None and old_best_loss is not None:
        if info['best_loss'] > old_best_loss:
            info['best_loss'] = old_best_loss
            info['best_pars'] = old_best_params

    # The current parameters are the best so far: write them to disk. The context manager
    # closes the file even if pickling fails.
    if info['best_loss'] == info['val_loss']:
        with open(TARGET_FILENAME + FILETYPE_EXTENSION, 'wb') as f:
            cPickle.dump(m.parameters.data, f, protocol=cPickle.HIGHEST_PROTOCOL)
0 0.0 135.728134155
1 0.0 137.329177856
2 0.0 138.017166138
3 0.0 136.646469116
4 0.0 137.377563477
5 0.0 136.652114868
6 0.0 153.01802063
7 0.0 150.019454956
8 0.0 149.642501831
9 0.0 151.729827881
10 0.0 152.342514038
11 0.0 149.720474243
12 0.0 153.643264771
13 0.0 151.117156982
14 0.0 150.744064331
15 0.0 149.049194336
16 0.0 148.793258667
17 0.0 147.925292969
18 0.0 147.513595581
19 0.0 147.302276611
20 0.0 147.232162476
21 0.0 145.816650391
22 0.0 146.539520264
23 0.0 145.536254883
24 0.0 145.720733643
25 0.0 145.059616089
26 0.0 145.311889648
27 0.0 145.04649353
28 0.0 144.228439331
29 0.0 143.770324707
30 0.0 143.697341919
31 0.0 143.731170654
32 0.0 143.44291687
33 0.0 143.471740723
34 0.0 142.900009155
35 0.0 142.350875854
36 0.0 142.380889893
37 0.0 141.800552368
38 0.0 141.519180298
39 0.0 141.341293335
40 0.0 140.35017395
41 0.0 140.473068237
42 0.0 140.384048462
43 0.0 140.162384033
44 0.0 140.259841919
45 0.0 139.321960449
46 0.0 139.183731079
47 0.0 139.03187561
48 0.0 139.280929565
49 0.0 138.968795776
50 0.0 139.049636841
51 0.0 138.887680054
52 0.0 139.01965332
53 0.0 138.951660156
54 0.0 138.503189087
55 0.0 139.173416138
56 0.0 138.639862061
57 0.0 138.6484375
58 0.0 138.627075195
59 0.0 138.507644653
60 0.0 138.736877441
61 0.0 138.753189087
62 0.0 138.609146118
63 0.0 138.39402771
64 0.0 138.362640381
65 0.0 138.498168945
66 0.0 138.464202881
67 0.0 138.356246948
68 0.0 138.370864868
69 0.0 138.555526733
70 0.0 138.45362854
71 0.0 138.572341919
72 0.0 138.354385376
73 0.0 138.58241272
74 0.0 137.885742188
75 0.0 138.029144287
76 0.0 138.127670288
77 0.0 138.107589722
78 0.0 137.663986206
79 0.0 137.698440552
80 0.0 137.592102051
81 0.0 137.711334229
82 0.0 137.91394043
83 0.0 137.600524902
84 0.0 137.656997681
85 0.0 137.624450684
86 0.0 137.813796997
87 0.0 137.513748169
88 0.0 137.77897644
89 0.0 137.764343262
90 0.0 137.407089233
91 0.0 137.825714111
92 0.0 137.990081787
93 0.0 137.91368103
94 0.0 138.093017578
95 0.0 137.380493164
96 0.0 137.864257812
97 0.0 137.479293823
98 0.0 137.435699463
99 0.0 137.861541748
100 0.0 137.411636353
101 0.0 137.686218262
102 0.0 137.409866333
103 0.0 137.509750366
104 0.0 137.852874756
105 0.0 137.879989624
106 0.0 138.128616333
107 0.0 137.363525391
108 0.0 137.600372314
109 0.0 138.220703125
110 0.0 137.753143311
111 0.0 137.678619385
112 0.0 137.696716309
113 0.0 137.672607422
114 0.0 137.373077393
115 0.0 137.5128479
116 0.0 137.286941528
117 0.0 137.679748535
118 0.0 137.36428833
119 0.0 137.314468384
120 0.0 137.38104248
121 0.0 137.526016235
122 0.0 137.430465698
123 0.0 137.347396851
124 0.0 137.696746826
125 0.0 137.733139038
126 0.0 137.774627686
127 0.0 137.878189087
128 0.0 137.624130249
129 0.0 137.250778198
130 0.0 137.392944336
131 0.0 137.384368896
132 0.0 137.198165894
133 0.0 137.169769287
134 0.0 137.786819458
135 0.0 137.446395874
136 0.0 137.350341797
137 0.0 137.392425537
138 0.0 137.3775177
139 0.0 137.000854492
140 0.0 137.113494873
141 0.0 137.07144165
142 0.0 137.373977661
143 0.0 137.184127808
144 0.0 137.581985474
145 0.0 137.394241333
146 0.0 137.670120239
147 0.0 137.564163208
148 0.0 137.223693848
149 0.0 137.328643799
150 0.0 137.390914917
151 0.0 137.458740234
152 0.0 137.079025269
153 0.0 137.16633606
154 0.0 137.43548584
155 0.0 136.948577881
156 0.0 136.953140259
157 0.0 137.353591919
158 0.0 137.050415039
159 0.0 137.201446533
160 0.0 137.663269043
161 0.0 136.88394165
162 0.0 137.150665283
163 0.0 137.064315796
164 0.0 137.328094482
165 0.0 137.602966309
166 0.0 137.516189575
167 0.0 137.147979736
168 0.0 137.076721191
169 0.0 137.071975708
170 0.0 137.451965332
171 0.0 137.146316528
172 0.0 137.272628784
173 0.0 136.682037354
174 0.0 137.126617432
175 0.0 136.840194702
176 0.0 137.065139771
177 0.0 137.001296997
178 0.0 136.840591431
179 0.0 137.178741455
180 0.0 137.871871948
181 0.0 136.901550293
182 0.0 136.842498779
183 0.0 136.687423706
184 0.0 136.625717163
185 0.0 136.859802246
186 0.0 136.922576904
187 0.0 136.894073486
188 0.0 137.056594849
189 0.0 136.801452637
190 0.0 137.107223511
191 0.0 136.855926514
192 0.0 137.393875122
193 0.0 136.94519043
194 0.0 136.985870361
195 0.0 137.314117432
196 0.0 136.936645508
197 0.0 136.916320801
198 0.0 136.924301147
199 0.0 136.762496948
200 0.0 136.897750854
201 0.0 137.335144043
202 0.0 136.956573486
203 0.0 137.078491211
204 0.0 136.973602295
205 0.0 136.964614868
206 0.0 136.970565796
207 0.0 136.638427734
208 0.0 136.635726929
209 0.0 136.673477173
210 0.0 136.563415527
211 0.0 136.951095581
212 0.0 136.630477905
213 0.0 136.468215942
214 0.0 136.792373657
215 0.0 136.596221924
216 0.0 136.492889404
217 0.0 136.521453857
218 0.0 136.475479126
219 0.0 136.51789856
220 0.0 136.310775757
221 0.0 136.413513184
222 0.0 136.455886841
223 0.0 136.547439575
224 0.0 136.204315186
225 0.0 136.445800781
226 0.0 136.349899292
227 0.0 136.637313843
228 0.0 136.235183716
229 0.0 136.680175781
230 0.0 136.320465088
231 0.0 136.053344727
232 0.0 136.484375
233 0.0 136.282928467
234 0.0 136.131164551
235 0.0 135.892929077
236 0.0 135.982620239
237 0.0 136.212539673
238 0.0 136.251678467
239 0.0 136.282577515
240 0.0 136.129333496
241 0.0 136.163864136
242 0.0 135.713638306
243 0.0 135.919143677
244 0.0 136.183807373
245 0.0 136.197174072
246 0.0 136.341629028
247 0.0 136.092941284
248 0.0 136.132385254
249 0.0 136.103820801
250 0.0 135.899505615
251 0.0 135.970214844
252 0.0 136.150039673
253 0.0 135.954360962
254 0.0 136.552215576
255 0.0 136.070327759
256 0.0 136.158843994
257 0.0 136.106018066
258 0.0 136.054458618
259 0.0 136.117645264
260 0.0 136.275512695
261 0.0 136.403656006
262 0.0 136.392745972
263 0.0 136.093826294
264 0.0 136.126998901
265 0.0 136.390151978
266 0.0 136.343322754
267 0.0 136.512329102
268 0.0 136.41494751
269 0.0 136.52305603
270 0.0 136.535949707
271 0.0 137.070144653
272 0.0 136.522583008
273 0.0 136.246627808
274 0.0 136.618850708
275 0.0 136.512893677
276 0.0 136.864364624
277 0.0 136.056945801
278 0.0 136.077377319
279 0.0 135.969146729
280 0.0 136.060089111
281 0.0 136.019943237
282 0.0 136.195678711
283 0.0 136.036865234
284 0.0 135.749603271
285 0.0 136.034973145
286 0.0 136.369888306
287 0.0 136.048797607
288 0.0 136.527130127
289 0.0 136.275848389
290 0.0 136.444351196
291 0.0 136.924621582
292 0.0 136.881820679
293 0.0 136.782974243
294 0.0 136.252624512
295 0.0 136.3777771
296 0.0 136.902603149
297 0.0 136.554840088
298 0.0 136.588470459
299 0.0 136.601196289
300 0.0 136.888656616
301 0.0 136.582519531
302 0.0 136.793319702
303 0.0 136.613372803
304 0.0 136.75553894
305 0.0 136.692443848
306 0.0 136.840576172
307 0.0 136.535766602
308 0.0 136.611877441
309 0.0 136.503494263
310 0.0 136.540100098
311 0.0 136.57800293
312 0.0 136.360275269
313 0.0 136.507766724
314 0.0 136.553451538
315 0.0 136.827423096
316 0.0 136.646148682
317 0.0 136.789489746
318 0.0 136.891876221
319 0.0 136.732040405
320 0.0 137.154678345
321 0.0 136.988494873
322 0.0 136.986251831
323 0.0 137.288955688
324 0.0 137.407302856
325 0.0 137.293289185
326 0.0 137.356491089
327 0.0 137.079299927
328 0.0 137.061004639
329 0.0 137.092544556
330 0.0 137.174118042
331 0.0 137.188842773
332 0.0 137.604370117
333 0.0 137.119491577
334 0.0 137.160339355
335 0.0 137.219589233
336 0.0 137.106628418
337 0.0 137.288482666
338 0.0 137.28112793
339 0.0 137.076370239
340 0.0 137.128265381
341 0.0 137.031265259
342 0.0 137.293991089
343 0.0 136.827072144
344 0.0 137.390075684
345 0.0 137.353164673
346 0.0 137.200454712
347 0.0 137.171722412
348 0.0 137.225738525
349 0.0 137.084854126
350 0.0 137.327667236
351 0.0 137.246017456
352 0.0 137.197998047
353 0.0 137.487091064
354 0.0 137.718399048
355 0.0 137.520645142
356 0.0 137.140869141
357 0.0 137.386306763
358 0.0 138.040664673
359 0.0 137.246627808
360 0.0 137.214233398
361 0.0 137.145721436
362 0.0 137.051177979
363 0.0 136.910705566
364 0.0 137.309356689
365 0.0 137.237915039
366 0.0 137.534790039
367 0.0 137.247192383
368 0.0 136.976547241
369 0.0 137.037475586
370 0.0 136.906967163
371 0.0 137.030990601
372 0.0 137.565322876
373 0.0 137.120697021
374 0.0 137.029754639
375 0.0 137.568939209
376 0.0 137.077392578
377 0.0 136.966796875
378 0.0 137.082839966
379 0.0 137.057800293
380 0.0 137.224975586
381 0.0 137.118301392
382 0.0 137.249099731
383 0.0 136.845321655
384 0.0 136.856445312
385 0.0 136.786361694
386 0.0 136.771652222
387 0.0 136.672348022
388 0.0 136.710601807
389 0.0 136.576919556
390 0.0 137.269073486
391 0.0 136.970153809
392 0.0 137.001785278
393 0.0 137.088165283
394 0.0 137.107696533
395 0.0 137.238540649
396 0.0 137.109939575
397 0.0 136.91166687
398 0.0 137.008972168
399 0.0 137.196624756
400 0.0 136.96194458
401 0.0 137.253067017
402 0.0 136.946304321
403 0.0 136.68522644
404 0.0 137.035293579
405 0.0 137.050720215
406 0.0 137.270751953
407 0.0 137.02166748
408 0.0 136.812438965
409 0.0 136.685348511
410 0.0 136.708984375
411 0.0 136.649505615
412 0.0 136.850830078
413 0.0 136.758773804
414 0.0 136.638504028
415 0.0 137.008743286
416 0.0 136.748016357
417 0.0 136.78453064
418 0.0 137.219177246
419 0.0 137.082122803
420 0.0 137.104614258
421 0.0 137.236312866
422 0.0 137.704818726
423 0.0 137.014724731
424 0.0 137.484451294
425 0.0 137.187133789
426 0.0 136.973739624
427 0.0 137.468139648
428 0.0 137.11807251
429 0.0 136.931137085
430 0.0 136.676635742
431 0.0 136.995925903
432 0.0 136.945281982
433 0.0 136.769302368
434 0.0 136.929321289
435 0.0 136.698364258
436 0.0 136.657165527
437 0.0 137.066085815
438 0.0 136.907714844
439 0.0 137.049835205
440 0.0 136.715118408
441 0.0 136.878540039
442 0.0 136.999526978
443 0.0 136.534912109
444 0.0 136.582946777
445 0.0 136.815811157
446 0.0 136.58883667
447 0.0 136.326522827
448 0.0 136.990203857
449 0.0 136.564666748
450 0.0 137.148147583
451 0.0 136.779174805
452 0.0 136.870544434
453 0.0 136.73727417
454 0.0 136.654769897
455 0.0 136.753967285
456 0.0 136.866867065
457 0.0 136.792678833
458 0.0 136.848922729
459 0.0 137.008392334
460 0.0 136.870819092
461 0.0 136.584701538
462 0.0 136.638076782
463 0.0 136.903427124
464 0.0 136.690597534
465 0.0 136.576217651
466 0.0 137.01399231
467 0.0 136.45741272
468 0.0 136.458099365
469 0.0 136.445571899
470 0.0 136.388977051
471 0.0 136.535400391
472 0.0 136.73614502
473 0.0 136.653915405
474 0.0 136.457077026
475 0.0 136.422317505
476 0.0 136.640945435
477 0.0 136.782501221
478 0.0 137.171142578
479 0.0 136.836364746
480 0.0 136.872772217
481 0.0 137.033615112
482 0.0 136.854873657
483 0.0 137.200057983
484 0.0 136.557189941
485 0.0 136.375305176
486 0.0 136.7993927
487 0.0 136.667449951
488 0.0 136.863723755
489 0.0 136.857620239
490 0.0 136.651565552
491 0.0 136.604202271
492 0.0 136.738327026
493 0.0 136.492034912
494 0.0 136.757827759
495 0.0 136.609115601
496 0.0 136.651397705
497 0.0 136.727813721
498 0.0 136.457397461
499 0.0 136.528152466

In [181]:
# Roll the model back to the best parameters recorded in the last training report and
# display the validation score they achieve.
# (Disabled: train_result_params = m.parameters.data.copy() would keep the end-of-training
# parameters.)
best_pars = info['best_pars']
m.parameters.data = best_pars
m.score(VX)
Out[181]:
garray(135.9841766357422)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [182]:
def _compile_np(inputs, output):
    """Compile `output` as a function of `inputs` via the model, requesting NumPy results."""
    return m.function(inputs, output, numpy_result=True)


# Latent samples: from the initial recognition model and from the HMC sampler's output.
f_z_init_sample = _compile_np(['inpt'], m.init_recog.sample())
f_z_sample = _compile_np(['inpt'], m.hmc_sampler.output)

# Generative model, fed with latents directly: a sample and the rate.
f_gen = _compile_np([m.gen.inpt], m.gen.sample())
f_gen_rate = _compile_np([m.gen.inpt], m.gen.rate)

# Joint negative log-likelihood as exposed by the model.
f_joint_nll = _compile_np(['inpt'], m.joint_nll)
In [183]:
# Symbolic placeholders that allow the individual pieces of the HMC sampler to be driven
# by hand with explicit positions, velocities and noise.
curr_pos = T.matrix('current_position')
curr_vel = T.matrix('current_velocity')
norm_noise = T.matrix('normal_noise')
unif_noise = T.vector('uniform_noise')

# Velocity produced by the kinetic-energy model from the supplied normal noise.
new_sampled_vel = m.hmc_sampler.kin_energy.sample(norm_noise)
# Weighted mix of the current velocity and the freshly sampled one, using the sampler's
# two partial-update coefficients.
updated_vel = m.hmc_sampler.partial_vel_constant * curr_vel + m.hmc_sampler.partial_vel_complement * new_sampled_vel
# NOTE(review): presumably the whole chain of HMC steps started at (curr_pos, curr_vel) -- confirm.
performed_hmc_steps = m.hmc_sampler.perform_hmc_steps(curr_pos, curr_vel)
# One HMC step; the third argument is a float32 zero.
# NOTE(review): looks like a step/time index -- confirm against hmc_step's signature.
hmc_step = m.hmc_sampler.hmc_step(curr_pos, curr_vel, np.float32(0), norm_noise, unif_noise)
# Dynamics simulation with return_full_list=True.
# NOTE(review): presumably returns every intermediate leapfrog state -- confirm.
lf_step_results = m.hmc_sampler.simulate_dynamics(curr_pos, curr_vel, return_full_list=True)

# Compiled functions. Only the two energy functions request NumPy results.
f_pot_en = m.function(['inpt', curr_pos], m.hmc_sampler.eval_pot_energy(curr_pos), numpy_result=True)
# Kinetic-energy negative log-likelihood, summed over the last axis.
f_kin_en = m.function(['inpt', curr_vel], m.kin_energy.nll(curr_vel).sum(-1), numpy_result=True)
# Elements 0 and 1 of the HMC-steps result, joined along axis 1
# (presumably positions and velocities -- confirm).
f_perform_hmc_steps = m.function(['inpt', curr_pos, curr_vel], 
                                T.concatenate([performed_hmc_steps[0], performed_hmc_steps[1]], axis=1))
# Same layout for a single step. on_unused_input='warn' is passed along, presumably so that
# inputs missing from the graph are tolerated (the model is built with
# perform_acceptance_step=False) -- confirm.
f_hmc_step = m.function(['inpt', curr_pos, curr_vel, norm_noise, unif_noise], 
                        T.concatenate([hmc_step[0], hmc_step[1]],axis=1), on_unused_input='warn')
f_kin_energy_sample_from_noise = m.function(['inpt', norm_noise], new_sampled_vel)
f_updated_vel_from_noise = m.function(['inpt', curr_vel, norm_noise], updated_vel)
# Elements 0 and 1 of the dynamics result, joined along axis 0.
f_perform_lf_steps = m.function(['inpt', curr_pos, curr_vel],
                               T.concatenate([lf_step_results[0], lf_step_results[1]], axis=0))
In [184]:
# Mean and variance of the initial recognition model.
init_recog = m.init_recog
f_z_init_mean = m.function(['inpt'], init_recog.mean, numpy_result=True)
f_z_init_var = m.function(['inpt'], init_recog.var, numpy_result=True)

# Variance of the kinetic-energy model, passed through cpu_contiguous before compilation.
kin_var = T.extra_ops.cpu_contiguous(m.kin_energy.var)
f_v_init_var = m.function(['inpt'], kin_var, numpy_result=True)

# Sampler run that also returns its path; elements 0 and 1 are joined along axis 1
# (presumably positions and velocities -- confirm).
full_sample = m.hmc_sampler.sample_with_path()
path_parts = [full_sample[0], full_sample[1]]
f_full_sample = m.function(['inpt'], T.concatenate(path_parts, axis=1))
In [217]:
# Free symbolic inputs for evaluating the final-velocity model at arbitrary states.
final_pos = T.matrix('final_pos')
final_vel = T.matrix('final_vel')

# Substitutions for the final-velocity model's own inputs: its position input becomes
# final_pos, its time input becomes the number of HMC steps cast to float32.
inpt_replacements = {}
inpt_replacements[m.final_vel_model_inpt['position']] = final_pos
inpt_replacements[m.final_vel_model_inpt['time']] = T.cast(m.hmc_sampler.n_hmc_steps,
                                                         dtype='float32')


def _with_final_inputs(expr):
    """Clone `expr` with the final-velocity model's inputs swapped as described above."""
    return clone(expr, replace=inpt_replacements)


final_vel_model_var = _with_final_inputs(m.final_vel_model.var)
final_vel_model_mean = _with_final_inputs(m.final_vel_model.mean)
# Negative log-likelihood of final_vel under the model, summed over the last axis.
final_vel_model_nll = _with_final_inputs(m.final_vel_model.nll(final_vel).sum(-1))

f_v_final_var = m.function(['inpt', final_pos], final_vel_model_var, numpy_result=True)
f_v_final_mean = m.function(['inpt', final_pos], final_vel_model_mean, numpy_result=True)
f_v_final_model_nll = m.function(['inpt', final_pos, final_vel], final_vel_model_nll,
                                 numpy_result=True)

# Expected negative log-likelihood exposed by the kinetic-energy model.
f_kin_energy_nll = m.function(['inpt'], m.kin_energy.expected_nll, numpy_result=True)
In [186]:
# Compiled function: expected negative log-likelihood of the initial recognition model,
# summed over the last axis, returned as a NumPy result.
f_init_recog_nll = m.function(['inpt'], m.init_recog.expected_nll.sum(-1), numpy_result=True)
In [187]:
# Mean expected NLL of the initial recognition model over the validation inputs.
print(f_init_recog_nll(VX).mean())

# Summary statistics (mean, max, min) of its variances on the same inputs.
init_var = f_z_init_var(VX)
for stat in ('mean', 'max', 'min'):
    print(getattr(init_var, stat)())
-3.33286
0.00295931
0.117392
4.20751e-05

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [188]:
def _tiles(images):
    """Arrange `images` into one tiled picture using the notebook's image_dims."""
    return tile_raster_images(images, image_dims, (8, 8), (1, 1))


fig, axs = plt.subplots(2, 3, figsize=(27, 18))

### Original data (left column): un-binarized on top, binarized below.

O = X_no_bin_np[:64][:, :784].astype('float32')
img = _tiles(O)
axs[0, 0].imshow(img, cmap=cm.binary)

O2 = X_np[:64][:, :784].astype('float32')
img = _tiles(O2)
axs[1, 0].imshow(img, cmap=cm.binary)

### Reconstruction

# Disabled alternative: z_sample = f_z_sample(X[:64])
batch = X[:64]

# Initial latent sample from the recognition model.
z_init_sample = cast_array_to_local_type(f_z_init_sample(batch))

# Velocity derived from fresh standard-normal noise, then the hand-driven HMC steps.
velocity_noise = cast_array_to_local_type(
    np.random.normal(size=(64, m.n_latent)).astype('float32'))
init_velocity = f_kin_energy_sample_from_noise(batch, velocity_noise)
hmc_result = f_perform_hmc_steps(batch, z_init_sample, init_velocity)
# Last entry along the first axis, first 64 entries along the second.
z_sample = hmc_result[-1, :64, :]

# Top row: generative rates for the HMC latents (middle) and the initial latents (right).
R = f_gen_rate(z_sample)[:, :784].astype('float32')
img = _tiles(R)
axs[0, 1].imshow(img, cmap=cm.binary)

Rinit = f_gen_rate(z_init_sample)[:, :784].astype('float32')
img = _tiles(Rinit)
axs[0, 2].imshow(img, cmap=cm.binary)

# Bottom row: generative samples for the same two sets of latents.
R2 = f_gen(z_sample)[:, :784].astype('float32')
img = _tiles(R2)
axs[1, 1].imshow(img, cmap=cm.binary)

Rinit2 = f_gen(z_init_sample)[:, :784].astype('float32')
img = _tiles(Rinit2)
axs[1, 2].imshow(img, cmap=cm.binary)
Out[188]:
<matplotlib.image.AxesImage at 0xe1e54dd8>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [189]:
# 64 latent vectors drawn from a standard normal, pushed through the generative model.
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
rate_ax, sample_ax = axs

prior_sample = cast_array_to_local_type(np.random.randn(64, m.n_latent))

# Rates first, samples second (same call order as before).
S = f_gen_rate(prior_sample)[:, :784].astype('float32')
S2 = f_gen(prior_sample)[:, :784].astype('float32')

img = tile_raster_images(S, image_dims, (8, 8), (1, 1))
rate_ax.imshow(img, cmap=cm.binary)

# Disabled third panel (rates drawn with the nipy_spectral colour map):
#S3 = f_gen_rate(prior_sample)[:, :784].astype('float32')
#img = tile_raster_images(S, image_dims, (8, 8), (1, 1))
#axs[2, 2].imshow(img, cmap=cm.nipy_spectral)

img = tile_raster_images(S2, image_dims, (8, 8), (1, 1))
sample_ax.imshow(img, cmap=cm.binary)
Out[189]:
<matplotlib.image.AxesImage at 0xe5abf6d8>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [190]:
# TODO: axis titles and a plot title; support picking two dimensions out of more than two
# (i.e. n_latent > 2).

from scipy.stats import norm as normal_distribution

# Side length of the square grid of latent positions.
n_grid = 20

# Evenly spaced probabilities, mapped through the standard-normal inverse CDF.
unit_interval_positions = np.linspace(0.025, 0.975, n_grid)
positions = normal_distribution.ppf(unit_interval_positions)
print(unit_interval_positions)
print(positions)

# One row per grid cell. Column 0 cycles through the positions fastest; column 1 is negated
# because images are filled top -> bottom, left -> right (row by row).
latent_array = np.column_stack([np.tile(positions, n_grid),
                                -np.repeat(positions, n_grid)])

F = f_gen_rate(cast_array_to_local_type(latent_array)).astype('float32')
img = tile_raster_images(F, image_dims, (20, 20), (1, 1))

fig, axs = plt.subplots(1, 1, figsize=(24, 24))
# Alternative colour map: axs.imshow(img, cmap=cm.nipy_spectral)
axs.imshow(img, cmap=cm.binary)
[ 0.025  0.075  0.125  0.175  0.225  0.275  0.325  0.375  0.425  0.475
  0.525  0.575  0.625  0.675  0.725  0.775  0.825  0.875  0.925  0.975]
[-1.95996398 -1.43953147 -1.15034938 -0.93458929 -0.75541503 -0.59776013
 -0.45376219 -0.31863936 -0.18911843 -0.06270678  0.06270678  0.18911843
  0.31863936  0.45376219  0.59776013  0.75541503  0.93458929  1.15034938
  1.43953147  1.95996398]

Out[190]:
<matplotlib.image.AxesImage at 0x46787240>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [191]:
# Latent codes for the whole training set X; both arrays are indexed as
# [:, dim] by the scatter plots that follow.
# NOTE(review): presumably L holds samples after the HMC steps and L_init samples
# from the initial recognition model (the scatter titles label them that way) --
# confirm against the definitions of f_z_sample / f_z_init_sample.
L = f_z_sample(X)
L_init = f_z_init_sample(X)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [192]:
# Indices of the two latent dimensions used as x- and y-coordinate in the 2-D plots.
dim1, dim2 = 0, 1
In [193]:
# Scatter the training-set latent codes in the (dim1, dim2) plane, coloured by
# digit class, for L (left) and L_init (right), plus a hand-made colour legend.

# Integer class per training example, recovered from the one-hot matrix Z.
# Computed once and shared by both panels.
class_labels = Z[:].argmax(1)

fig, axs = plt.subplots(1, 2, figsize=(18, 9))
for panel, codes, title in ((axs[0], L, 'After HMC steps'),
                            (axs[1], L_init, 'Initial recognition model')):
    panel.scatter(codes[:, dim1], codes[:, dim2], c=class_labels, lw=0, s=5, alpha=.2)
    panel.set_title(title)
    panel.set_xlim(-3, 3)
    panel.set_ylim(-3, 3)

# Legend: a narrow extra axes with one large dot per digit, coloured with the same
# 0..9 values as the scatter plots. All styling goes through `cax` explicitly, so
# it does not depend on which axes pyplot currently considers active.
cax = fig.add_axes([0.95, 0.2, 0.02, 0.6])
cax.scatter(np.repeat(0, 10), np.arange(10), c=np.arange(10), lw=0, s=300)
cax.set_xlim(-0.1, 0.1)
cax.set_ylim(-0.5, 9.5)
cax.set_yticks(np.arange(10))
cax.tick_params(axis='x', which='both', bottom='off', top='off', labelbottom='off')
# Hide the y tick marks (white on white) but keep the digit labels readable.
cax.tick_params(axis='y', colors='white')
for tick in cax.yaxis.get_major_ticks():
    tick.label.set_fontsize(14)
    tick.label.set_color('black')

# Hide the frame of the legend axes.
for side in ('bottom', 'top', 'right', 'left'):
    cax.spines[side].set_color('white')
Out[193]:
(-3, 3)
In [194]:
# One pair of panels per digit class: latent codes from L_init ("before HMC") on
# top of the codes from L ("after HMC"). Digits 0-4 use rows 0/1, digits 5-9 rows 2/3.
fig, axs = plt.subplots(4, 5, figsize=(20, 16))
colors = cm.jet(np.linspace(0, 1, 10))

# Integer class per training example; computed once instead of for every scatter call.
class_labels = Z[:].argmax(1)

for digit in range(10):
    top_row = 2 * (digit // 5)
    column = digit % 5
    members = class_labels == digit
    for row, codes, stage in ((top_row, L_init, ' before HMC'),
                              (top_row + 1, L, ' after HMC')):
        panel = axs[row, column]
        panel.scatter(codes[members, dim1], codes[members, dim2], c=colors[digit], lw=0, s=5, alpha=.2)
        panel.set_title(str(digit) + stage)
        # Same window on every panel so the clusters are directly comparable.
        panel.set_xlim(-3, 3)
        panel.set_ylim(-3, 3)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [195]:
# Choose the training example that the single-image diagnostics below work on.
# Known examples: index=0 -> 5, index=1 -> 0, index=2 -> 4, index=3 -> 1,
# index=24 -> underlined 1, index=39 -> ugly 6
X_index = 0
# Number of copies of that example that are sampled together further down.
num_repeats = 1000

# Left: the binarised image (row of X); right: the original grey-scale image (row of X_no_bin).
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
binarised_panel, original_panel = axs

binarised_panel.imshow(
    tile_raster_images(np.array([X[X_index, :]]), image_dims, (1, 1), (1, 1)),
    cmap=cm.binary)

img = tile_raster_images(np.array([X_no_bin[X_index, :]]), image_dims, (1, 1), (1, 1))
original_panel.imshow(img, cmap=cm.binary)
Out[195]:
<matplotlib.image.AxesImage at 0xe66b3048>
In [196]:
repeated_X = cast_array_to_local_type(np.tile(np.array([X[X_index, :]]), (num_repeats, 1)).astype('float32'))

full_sample = f_full_sample(repeated_X).astype('float32')
z_samples = full_sample[:, :num_repeats, :]
v_samples = full_sample[:, num_repeats:, :]

z_sample_final_mean = z_samples[m.n_hmc_steps, :, :].mean(axis=0)
z_sample_final_std = z_samples[m.n_hmc_steps, :, :].std(axis=0)

single_X = cast_array_to_local_type(np.array([X[X_index, :]]).astype('float32'))
init_mean = f_z_init_mean(single_X)[0]
init_var = f_z_init_var(single_X)[0]

init_vel_var = f_v_init_var(single_X)[0]

print 'Posterior distribution statistics'
print
print 'Initial model: - Mean: ' + str(init_mean)
print '               - Var:  ' + str(init_var)
print
print 'Full HVI model: - Mean: ' + str(z_sample_final_mean)
print '                - Var:  ' + str(z_sample_final_std ** 2)
print
print 'Velocity model variance: ' + str(init_vel_var)
Posterior distribution statistics

Initial model: - Mean: [ 0.8341707  0.777879 ]
               - Var:  [ 0.00265309  0.0015152 ]

Full HVI model: - Mean: [ 0.82473713  0.80002463]
                - Var:  [ 0.00046972  0.00100139]

Velocity model variance: [ 0.3143298   0.58953619]

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [197]:
# Indices of the two latent dimensions spanned by the energy-surface plots below.
dim1, dim2 = 0, 1
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [198]:
resolution = 201
lower_dim1_limit = z_sample_final_mean[dim1] - 0.2
upper_dim1_limit = z_sample_final_mean[dim1] + 0.2
lower_dim2_limit = z_sample_final_mean[dim2] - 0.2
upper_dim2_limit = z_sample_final_mean[dim2] + 0.2

number_images_per_axis = 11
latent_array = np.zeros((number_images_per_axis**2, 2))
gap_between_images = (resolution - 1)//(number_images_per_axis - 1)
indices_for_images = np.arange(0, resolution, gap_between_images)

pot_energy_matrix = np.zeros((resolution, resolution), dtype='float32')
x = np.linspace(lower_dim1_limit, upper_dim1_limit, resolution)
y = np.linspace(lower_dim2_limit, upper_dim2_limit, resolution)
for i in range(resolution):
    for j in range(resolution):
        #pos_array = f_z_init_mean(single_X)
        pos_array = np.array([z_sample_final_mean])
        pos_array[0, dim1] = x[i]
        pos_array[0, dim2] = y[j]
        pos_array_of_local_type = cast_array_to_local_type(pos_array)
        pot_energy_matrix[j, i] = f_pot_en(single_X, pos_array_of_local_type)[0]
        if i in indices_for_images and j in indices_for_images:
            latent_array[(i//gap_between_images) + (number_images_per_axis - 1 - j//gap_between_images)*number_images_per_axis , :] = pos_array[0, :]

        
print 'Minimum potential energy (at grid points): ' + str(pot_energy_matrix.min())
print 'Maximum potential energy (at grid points): ' + str(pot_energy_matrix.max())

fig, axs = plt.subplots(1, 2, figsize=(18, 9))
CS = axs[0].contour(x, y, pot_energy_matrix, 20)
plt.clabel(CS, inline=1, fmt='%1.0f', fontsize=10)
axs[0].set_title('Potential energy surface')

F = f_gen_rate(cast_array_to_local_type(latent_array))
img = tile_raster_images(F, image_dims, (number_images_per_axis, number_images_per_axis), (1, 1))
#axs.imshow(img, cmap=cm.nipy_spectral)
axs[1].imshow(img, cmap=cm.binary)
plt.show()
Minimum potential energy (at grid points): 179.986
Maximum potential energy (at grid points): 297.216

In [199]:
resolution = 200
underlying_variance = f_v_init_var(single_X)
velocity_range_for_images = 10.0 * np.sqrt(underlying_variance[0, :])
lower_dim1_limit = np.around(- velocity_range_for_images[dim1])
upper_dim1_limit = np.around(  velocity_range_for_images[dim1])
lower_dim2_limit = np.around(- velocity_range_for_images[dim2])
upper_dim2_limit = np.around(  velocity_range_for_images[dim2])

kin_energy_matrix = np.zeros((resolution, resolution), dtype='float32')
kin_x = np.linspace(lower_dim1_limit, upper_dim1_limit, resolution)
kin_y = np.linspace(lower_dim2_limit, upper_dim2_limit, resolution)
for i in range(resolution):
    for j in range(resolution):
        vel_array = np.zeros((1, m.n_latent)).astype('float32')
        vel_array[0, dim1] = kin_x[i]
        vel_array[0, dim2] = kin_y[j]
        vel_array_of_local_type = cast_array_to_local_type(vel_array)
        kin_energy_matrix[j, i] = f_kin_en(single_X, vel_array_of_local_type)

print 'Minimum kinetic energy (at grid points): ' + str(kin_energy_matrix.min())
print 'Maximum kinetic energy (at grid points): ' + str(kin_energy_matrix.max())

fig, ax = plt.subplots(1, 1, figsize=(9, 9))
CS = ax.contour(kin_x, kin_y, kin_energy_matrix)
plt.clabel(CS, inline=1, fmt='%1.1f', fontsize=10)
ax.set_title('Kinetic energy surface')
plt.show()
Minimum kinetic energy (at grid points): 0.997828
Maximum kinetic energy (at grid points): 112.54

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [203]:
# Figure for the per-step diagnostics: one row of three panels for every index
# 0 .. m.n_hmc_steps. The loop further down fills column 0 with the potential
# energy contours and the positions, column 1 with the kinetic energy contours
# and the velocities, and column 2 with a histogram of potential energies.
fig, axs = plt.subplots(m.n_hmc_steps + 1, 3, figsize=(18, (m.n_hmc_steps + 1) * 6))
# Ten-colour jet palette, indexed by the integer codes from colour_for_z_samples.
colors = cm.jet(np.linspace(0, 1, 10))

# Contour levels for the potential energy surface. The commented-out tuples are
# alternative hand-picked level sets -- presumably for other images/models; the
# active one spans the energy range printed by the potential-energy cell.
#contour_levels = (198, 200, 202, 204, 206, 208, 210)
#contour_levels = (130, 140, 150, 160, 180, 200, 240, 280)
#contour_levels = (100, 102, 104, 106, 108, 110, 115, 120, 125, 130)
#contour_levels = (400, 402, 404, 406, 408, 410, 412, 416, 420)
#contour_levels = (106, 108, 110, 112, 114, 116, 118, 120, 124, 128)
contour_levels = (160, 165, 170, 175, 180, 185, 190, 195, 200, 210, 220, 230, 240, 250, 270, 300)
#contour_levels = (174, 175, 176, 177, 178, 180, 182, 184, 186, 190, 200)
#contour_levels = (59, 61, 63, 65, 67, 69, 71, 73, 75, 80, 85, 90)

# Contour levels for the kinetic energy surface: 18 evenly spaced values in [2, 70].
vel_contour_levels = np.linspace(2.0, 70.0, 18)
# Disabled alternative: filled contours of the potential energy in the first panel.
#CS0 = axs[0, 0].contourf(x, y, pot_energy_matrix, np.linspace(155, 240, 500))

def colour_for_z_samples(samples):
    """Assign each sample a colour code from its position relative to the sample mean.

    Only the columns `dim1` and `dim2` (module-level indices) are inspected.

    Codes:
        0 -- dim1 below the mean,     dim2 below the mean
        2 -- dim1 below the mean,     dim2 at or above the mean
        4 -- dim1 at or above the mean, dim2 below the mean
        7 -- dim1 at or above the mean, dim2 at or above the mean
        9 -- squared distance to the mean (in the two columns) below 1e-5;
             this overrides the quadrant code

    :param samples: 2-D array, one sample per row.
    :returns: int32 array with one code per row of `samples`.
    """
    centre = samples.mean(axis=0)
    first, second = samples[:, dim1], samples[:, dim2]
    centre1, centre2 = centre[dim1], centre[dim2]

    below1, above1 = first < centre1, first >= centre1
    below2, above2 = second < centre2, second >= centre2

    codes = np.zeros_like(samples[:, 0])
    quadrants = (
        (below1, below2, 0),
        (below1, above2, 2),
        (above1, below2, 4),
        (above1, above2, 7),
    )
    for in_dim1, in_dim2, code in quadrants:
        codes[np.logical_and(in_dim1, in_dim2)] = code

    # Samples that (almost) coincide with the mean get their own code.
    squared_distance = (first - centre1) ** 2 + (second - centre2) ** 2
    codes[squared_distance < 1e-5] = 9
    return codes.astype('int32')

colour = colour_for_z_samples(z_samples[m.n_hmc_steps,:,:])
print v_samples[m.n_hmc_steps, colour == 0, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 2, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 4, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 7, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 9, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 0, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 2, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 4, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 7, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 9, :].var(axis=0)

for i in range(m.n_hmc_steps + 1):
    CS = axs[i, 0].contour(x, y, pot_energy_matrix, contour_levels)
    plt.clabel(CS, inline=1, fmt='%1.0f', fontsize=10)
    axs[i, 0].scatter(z_samples[i,:,dim1], z_samples[i,:,dim2], c=colors[colour_for_z_samples(z_samples[i,:,:])], s=20, alpha=.3, lw=0)
    
    CS_vel = axs[i, 1].contour(kin_x, kin_y, kin_energy_matrix, vel_contour_levels)
    plt.clabel(CS_vel, inline=1, fmt='%1.1f', fontsize=10)
    axs[i, 1].scatter(v_samples[i,:,dim1], v_samples[i,:,dim2], c=colors[colour_for_z_samples(z_samples[i,:,:])], s=20, alpha=.3, lw=0)
    
    pot_energy_distrib = f_pot_en(repeated_X, cast_array_to_local_type(z_samples[i, :, :]))
    if i == 0:
        max_x_value_for_hist = pot_energy_distrib.max() + 5
        min_x_value_for_hist = np.floor(pot_energy_matrix.min()) -5
    pot_energy_distrib_mean = pot_energy_distrib.mean()
    axs[i, 2].hist(pot_energy_distrib, 30, normed=1, range=(min_x_value_for_hist, max_x_value_for_hist))
    axs[i, 2].autoscale(enable=False, axis='both')
    axs[i, 2].axvline(pot_energy_distrib_mean, color='r', linestyle='dashed', linewidth=2)
    axs[i, 2].set_xlim(min_x_value_for_hist, max_x_value_for_hist)
    axs[i, 2].text(pot_energy_distrib_mean + 1.0, 0.8*axs[i, 2].get_ylim()[1], 'Mean: ' + str(pot_energy_distrib_mean))
    axs[i, 1].set_xlim(-velocity_range_for_images[dim1], velocity_range_for_images[dim1])
    axs[i, 1].set_ylim(-velocity_range_for_images[dim2], velocity_range_for_images[dim2])

axs[0, 0].scatter(f_z_init_mean(single_X)[0, dim1], f_z_init_mean(single_X)[0, dim2], c='black', s=20)

plt.show()
[ 0.15367848  0.43081975]
[ 0.03977635  0.31012052]
[ 0.1896529   0.39525488]
[ 0.11319306  0.27881271]
[ 0.00642857  0.32013533]
[ 0.34421238  0.7316407 ]
[ 0.35730013  0.71919554]
[ 0.40691343  0.5457229 ]
[ 0.35922599  0.59237504]
[ 0.49072871  0.16278465]

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [204]:
# Trace one trajectory for the selected image step by step and record every
# intermediate position and velocity for the trajectory plot in the next cell.
# Fixed seed so that the velocity noise (and hence the trajectory) is repeatable.
np.random.seed(1)

# One noise vector per HMC step.
velocity_noise = cast_array_to_local_type(np.random.normal(size=(m.n_hmc_steps, 1, m.n_latent)))
# Uncomment to run the trajectory without velocity noise.
#velocity_noise = np.zeros_like(velocity_noise)

# Start slightly off the initial mean.
# NOTE(review): the two-element offset only broadcasts for n_latent == 2.
init_pos = f_z_init_mean(single_X) + np.array([0.0, 0.1])
init_vel = f_kin_energy_sample_from_noise(single_X, velocity_noise[0])

# Velocity slots stored per HMC step: slot 0 holds the velocity the step starts
# with, the remaining m.n_lf_steps + 1 slots hold the entries of
# vel_half_steps_and_final.
num_vels_per_hmc = (m.n_lf_steps + 2)
# Row 0 is the start position, followed by m.n_lf_steps rows per HMC step.
position_array = np.zeros((m.n_hmc_steps * m.n_lf_steps + 1, m.n_latent))
position_array[0] = init_pos
velocity_array = np.zeros((m.n_hmc_steps * num_vels_per_hmc, m.n_latent))
velocity_array[0] = ma.assert_numpy(init_vel)

for hmc_num in range(m.n_hmc_steps):
    if hmc_num == 0:
        # First step: use the start position and the velocity computed above.
        curr_pos = cast_array_to_local_type(init_pos)
        curr_vel = init_vel
    else:
        # Later steps: update the velocity left over from the previous step with
        # this step's noise and store it in slot 0 of this step's block
        # (m.n_lf_steps + 2 equals num_vels_per_hmc).
        curr_vel = f_updated_vel_from_noise(single_X, curr_vel, velocity_noise[hmc_num])
        velocity_array[hmc_num * (m.n_lf_steps + 2)] = ma.assert_numpy(curr_vel)
    
    # The first m.n_lf_steps entries of the result are treated as positions, the
    # rest as velocities; the very last entry is carried over as the final velocity.
    # NOTE(review): "lf" presumably stands for leapfrog -- confirm with the
    # definition of f_perform_lf_steps.
    lf_step_results = f_perform_lf_steps(single_X, curr_pos, curr_vel)
    pos_steps = lf_step_results[:m.n_lf_steps]
    vel_half_steps_and_final = lf_step_results[m.n_lf_steps:]
    final_vel = lf_step_results[-1]
    final_pos = pos_steps[-1]
    
    # Store this step's block; index 0 on axis 1 selects the single image of the batch.
    position_array[hmc_num * m.n_lf_steps + 1: (hmc_num + 1)*m.n_lf_steps + 1] = ma.assert_numpy(pos_steps[:, 0, :])
    velocity_array[hmc_num * num_vels_per_hmc + 1: (hmc_num + 1) * num_vels_per_hmc] = ma.assert_numpy(vel_half_steps_and_final[:, 0, :])
    
    # The end state of this step is the start state of the next one.
    curr_pos = final_pos
    curr_vel = final_vel
In [205]:
# Plot the recorded trajectory: positions over the potential energy contours
# (left) and velocities over the kinetic energy contours (right).
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
# One colour per recorded position, running through the jet colour map.
step_color = cm.jet(np.linspace(0, 1, position_array.shape[0]))
CS = axs[0].contour(x, y, pot_energy_matrix, contour_levels)
CS_vel = axs[1].contour(kin_x, kin_y, kin_energy_matrix, vel_contour_levels)
# Rows of position_array at which an HMC step starts/ends (every m.n_lf_steps rows);
# these are drawn with larger markers.
hmc_step_indices = np.arange(0, position_array.shape[0], m.n_lf_steps)
size_array = 40*np.ones((position_array.shape[0],))
size_array[hmc_step_indices] = 100
axs[0].scatter(position_array[:, dim1], position_array[:, dim2], c=step_color, lw=1, s=size_array)
# Line segments plotted below without an explicit colour take their colour from
# this cycle, in plotting order.
axs[1].set_color_cycle(step_color)

for hmc_num in range(m.n_hmc_steps):
    # Start indices of the segments within this step's block of velocity slots,
    # excluding the last segment (which is drawn by the second loop).
    curr_vel_range = np.arange(num_vels_per_hmc * hmc_num, num_vels_per_hmc * (hmc_num + 1) - 2)
    init_vel_ind = hmc_num * num_vels_per_hmc
    final_vel_ind = (hmc_num + 1) * num_vels_per_hmc - 1
    curr_index = hmc_step_indices[hmc_num]
    next_index = hmc_step_indices[hmc_num + 1]
    for j in curr_vel_range:
        # Segment between two consecutive velocity slots.
        axs[1].plot(velocity_array[j:j+2, dim1], velocity_array[j:j+2, dim2], lw=2)
    # Large dots for the first and last velocity slot of the step, coloured like
    # the positions at which the step starts and ends.
    axs[1].scatter(velocity_array[init_vel_ind, dim1], velocity_array[init_vel_ind, dim2], c=step_color[curr_index], lw=0, s=100)
    axs[1].scatter(velocity_array[final_vel_ind, dim1], velocity_array[final_vel_ind, dim2], c=step_color[next_index], lw=0, s=100)

# Second pass: the last segment of every step, drawn with an explicit colour
# (that of the position at which the step ends) after all other segments.
for hmc_num in range(m.n_hmc_steps):
    final_vel_ind = (hmc_num + 1) * num_vels_per_hmc - 1
    next_index = hmc_step_indices[hmc_num + 1]
    axs[1].plot(velocity_array[final_vel_ind-1:final_vel_ind+1, dim1], velocity_array[final_vel_ind-1:final_vel_ind+1, dim2], lw=2, c=step_color[next_index])
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [207]:
# Sweep each latent dimension in turn over mean +/- 2 std of the final samples
# (all other dimensions fixed at the mean) and record what f_v_final_mean and
# f_v_final_var return along the sweep.
two_std = 2*z_sample_final_std
variation_start = z_sample_final_mean - two_std
variation_end = z_sample_final_mean + two_std

# Axis 0: swept dimension, axis 1: position along the sweep, axis 2: output dimension.
sweep_output_shape = (m.n_latent, num_repeats, m.n_latent)
final_vel_model_mean_output = np.zeros(sweep_output_shape)
final_vel_model_var_output = np.zeros(sweep_output_shape)

for variation_dim in range(m.n_latent):
    z_variation = np.linspace(variation_start[variation_dim], variation_end[variation_dim], num_repeats)
    # Every row starts as the mean; only the swept column varies.
    sample_array = np.tile(z_sample_final_mean, (num_repeats, 1))
    sample_array[:, variation_dim] = z_variation
    final_vel_model_mean_output[variation_dim] = f_v_final_mean(
        repeated_X, cast_array_to_local_type(sample_array))
    final_vel_model_var_output[variation_dim] = f_v_final_var(
        repeated_X, cast_array_to_local_type(sample_array))
In [208]:
# Scatter the sweep results in the (dim1, dim2) plane: f_v_final_mean outputs on
# the left, f_v_final_var outputs on the right.
fig, axs = plt.subplots(1, 2, figsize=(18, 9))

# Colour value = index of the swept dimension, constant along each sweep;
# shape (m.n_latent, num_repeats) to match the sliced outputs.
sweep_colours = np.transpose(np.tile(np.linspace(0,m.n_latent-1,m.n_latent), (num_repeats, 1)))

for panel, sweep_output in zip(axs, (final_vel_model_mean_output, final_vel_model_var_output)):
    panel.scatter(sweep_output[:, :, dim1],
                  sweep_output[:, :, dim2],
                  c=sweep_colours,
                  lw=0, s=5)

plt.show()
In [212]:
# Evaluate the final velocity model on the positions and velocities stored at
# index m.n_hmc_steps, for all num_repeats chains of the selected image.
last_step = m.n_hmc_steps
final_z_samples = cast_array_to_local_type(z_samples[last_step, :, :])
final_v_samples = cast_array_to_local_type(v_samples[last_step, :, :])

model_inputs = (repeated_X, final_z_samples)
final_vel_mean = f_v_final_mean(*model_inputs)
final_vel_var = f_v_final_var(*model_inputs)
# NOTE(review): presumably one negative log-likelihood value per chain -- confirm
# with the definition of f_v_final_model_nll.
final_vel_nll = f_v_final_model_nll(repeated_X, final_z_samples, final_v_samples)
2.31926

In [218]:
print f_kin_energy_nll(single_X).sum(-1)

print final_vel_nll.mean()
print final_vel_nll.min()
print final_vel_nll.max()
[ 1.99501109]
2.31926
1.69163
6.26161

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [None]:
fig, axs = plt.subplots(4, 2, figsize=(18, 36))
# TODO: Analysis of how final_vel_mean and final_vel_var depend on z (since they all share the same x)

print z_samples[3, :, :].mean(axis=0)
print z_samples[3, :, :].var(axis=0)
print v_samples[3, :, :].mean(axis=0)
print v_samples[3, :, :].var(axis=0)
print f_v_init_var(np.array([X[X_index, :]]))

print final_vel_nll.mean()
plt.boxplot(final_vel_nll, whis=1)
plt.show()
In [None]:
# Per-digit mean and standard deviation of the latent codes, for f_z_sample
# (centers / stddevs) and f_z_init_sample (centers_init / stddevs_init).
# The latent width comes from the model (m.n_latent), as everywhere else in this
# notebook; the name `n_latents` used here before is not defined in the cells shown.
class_statistics_shape = (10, m.n_latent)
centers = np.zeros(class_statistics_shape)
stddevs = np.zeros(class_statistics_shape)
centers_init = np.zeros(class_statistics_shape)
stddevs_init = np.zeros(class_statistics_shape)

# Integer class per training example; computed once instead of twice per digit.
class_labels = Z.argmax(1)

for i in range(10):
    X_of_digit = X[class_labels == i]

    Li = f_z_sample(X_of_digit)
    centers[i] = Li.mean(axis=0)
    stddevs[i] = np.std(Li, axis=0)

    Li_init = f_z_init_sample(X_of_digit)
    centers_init[i] = Li_init.mean(axis=0)
    stddevs_init[i] = np.std(Li_init, axis=0)
In [None]:
# Per-digit cluster centres in the (dim1, dim2) plane.
# Left: centres from `centers` (default marker) and `centers_init` (squares).
# Right: centres from `centers` with one marker a standard deviation away in each
# of the four axis directions.
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
centre1, centre2 = centers[:, dim1], centers[:, dim2]
spread1, spread2 = stddevs[:, dim1], stddevs[:, dim2]

axs[0].scatter(centre1, centre2, c=range(10), s=50)
axs[0].scatter(centers_init[:, dim1], centers_init[:, dim2], c=range(10), s=50, marker=u's')

axs[1].scatter(centre1, centre2, c=range(10), s=50)
one_std_away = (
    (centre1 + spread1, centre2, u'>'),
    (centre1 - spread1, centre2, u'<'),
    (centre1, centre2 + spread2, u'^'),
    (centre1, centre2 - spread2, u'v'),
)
for shifted1, shifted2, arrow in one_std_away:
    axs[1].scatter(shifted1, shifted2, c=range(10), s=50, marker=arrow)

#axs[0].set_xlim(-1.2, 1.2)
#axs[0].set_ylim(-1.2, 1.2)
#axs[1].set_xlim(-1.2, 1.2)
#axs[1].set_ylim(-1.2, 1.2)

# Differences (centers/stddevs minus their *_init counterparts), first for the
# centres and then for the standard deviations, each for dim1 and then dim2.
for after, before in ((centers, centers_init), (stddevs, stddevs_init)):
    for dim in (dim1, dim2):
        print (after[:, dim] - before[:, dim])
In [None]: